{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# ThreadBuffer Performance\n",
    "\n",
    "This notebook demonstrates the use of `ThreadBuffer` to generate batches of data asynchronously from the training thread.\n",
    "\n",
    "Under certain circumstances the main thread is busy with training operations, that is, interacting with GPU memory and invoking CUDA operations, which are independent of batch generation. If the time taken to generate a batch is significant compared to the time taken to train the network for one iteration, and assuming the two can run in parallel given the limitations of the GIL or other factors, generating batches in a separate thread should speed up the whole training process. The efficiency gain depends on the ratio of these two times: if batch generation is lengthy but training is very fast, very little parallel computation is possible.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/acceleration/threadbuffer_performance.ipynb)"
   ]
  },
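  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea can be illustrated with a minimal sketch built only on the Python standard library. This is not MONAI's implementation of `ThreadBuffer`; the helper name `buffered` and its `buffer_size` argument are hypothetical, chosen purely for illustration. A background thread fills a small queue from the source iterable while the consumer works on the previous item:\n",
    "\n",
    "```python\n",
    "import queue\n",
    "import threading\n",
    "\n",
    "\n",
    "def buffered(src, buffer_size=1):\n",
    "    # Hypothetical helper: yields the items of `src` in order, while a\n",
    "    # background thread produces up to `buffer_size` items ahead.\n",
    "    q = queue.Queue(maxsize=buffer_size)\n",
    "    done = object()  # sentinel marking the end of `src`\n",
    "\n",
    "    def produce():\n",
    "        for item in src:\n",
    "            q.put(item)  # blocks while the queue is full\n",
    "        q.put(done)\n",
    "\n",
    "    threading.Thread(target=produce, daemon=True).start()\n",
    "\n",
    "    while True:\n",
    "        item = q.get()  # blocks until the producer has an item ready\n",
    "        if item is done:\n",
    "            return\n",
    "        yield item\n",
    "\n",
    "\n",
    "for value in buffered(range(3)):\n",
    "    print(value)\n",
    "```\n",
    "\n",
    "The sketch omits error handling: an exception raised in the producer thread would leave the consumer waiting, and abandoning the iteration early leaves the daemon thread blocked on the queue until the process exits."
   ]
  },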
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Environment\n",
    "\n",
    "The current MONAI main branch must be installed for this feature (as of release 0.9.1). Skip this step if it is already installed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import monai\n",
    "import numpy as np\n",
    "import torch\n",
    "from monai.data import DataLoader, Dataset, ThreadBuffer, create_test_image_2d\n",
    "from monai.losses import Dice\n",
    "from monai.networks.nets import UNet\n",
    "from monai.transforms import EnsureChannelFirstd, Compose, MapTransform\n",
    "\n",
    "monai.utils.set_determinism(seed=0)\n",
    "\n",
    "monai.config.print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The data pipeline defined here creates random 2D segmentation training pairs. It is artificially slowed by setting the number of worker processes to 0 (often necessary under Windows)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomGenerator(MapTransform):\n",
    "    \"\"\"Generates a dictionary containing image and\n",
    "    segmentation images from a given seed value.\"\"\"\n",
    "\n",
    "    def __call__(self, seed):\n",
    "        rs = np.random.RandomState(seed)\n",
    "        im, seg = create_test_image_2d(256, 256, num_seg_classes=1, random_state=rs)\n",
    "\n",
    "        return {self.keys[0]: im, self.keys[1]: seg}\n",
    "\n",
    "\n",
    "data = np.random.randint(0, monai.utils.MAX_SEED, 1000)\n",
    "\n",
    "trans = Compose(\n",
    "    [\n",
    "        RandomGenerator(keys=(\"im\", \"seg\")),\n",
    "        EnsureChannelFirstd(keys=(\"im\", \"seg\"), channel_dim=\"no_channel\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_ds = Dataset(data, trans)\n",
    "train_loader = DataLoader(train_ds, batch_size=20, shuffle=True, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The network, loss function, and optimizer are defined as normal:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "net = UNet(2, 1, 1, (8, 16, 32), (2, 2), num_res_units=2).to(device)\n",
    "loss_function = Dice(sigmoid=True)\n",
    "optimizer = torch.optim.Adam(net.parameters(), 1e-5)\n",
    "max_epochs = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simple training function is defined which performs only the optimization step for the network, with no validation or logging:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(batch):\n",
    "    inputs, labels = batch[\"im\"].to(device), batch[\"seg\"].to(device)\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    outputs = net(inputs)\n",
    "    loss = loss_function(outputs, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "def train(use_buffer):\n",
    "    # wrap the loader in the ThreadBuffer if selected\n",
    "    src = ThreadBuffer(train_loader, 1) if use_buffer else train_loader\n",
    "\n",
    "    for _ in range(max_epochs):\n",
    "        for batch in src:\n",
    "            train_step(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Timing the generation of a single batch against a single optimization step of the network shows how much of each full training iteration is spent on each:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "52.9 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "36.6 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "it = iter(train_loader)\n",
    "batch = next(it)\n",
    "\n",
    "%timeit -n 1 next(it)\n",
    "%timeit -n 1 train_step(batch)"
   ]
  },
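  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These two timings allow a rough estimate of what overlapping can gain. The sketch below assumes an idealized model: run sequentially, each iteration costs the sum of the two times, while with perfect overlap it costs only the larger of the two. It ignores GIL contention, host-to-device transfers, and the latency of the first batch, so real runs will likely be slower than either estimate. The timing values are copied from the recorded output above and will differ on other hardware:\n",
    "\n",
    "```python\n",
    "# Per-iteration timings from the recorded %timeit output, in seconds.\n",
    "t_batch = 52.9e-3  # generating one batch\n",
    "t_step = 36.6e-3  # one optimization step\n",
    "\n",
    "# 1000 samples with a batch size of 20 gives 50 iterations per epoch,\n",
    "# repeated for max_epochs = 10.\n",
    "n_iters = (1000 // 20) * 10\n",
    "\n",
    "# Sequential: every iteration pays for both operations in turn.\n",
    "sequential = n_iters * (t_batch + t_step)\n",
    "# Ideal overlap: every iteration pays only for the slower operation.\n",
    "overlapped = n_iters * max(t_batch, t_step)\n",
    "\n",
    "print(f'sequential estimate: {sequential:.2f} s')\n",
    "print(f'overlapped estimate: {overlapped:.2f} s')\n",
    "print(f'best-case speedup: {sequential / overlapped:.2f}x')\n",
    "```\n",
    "\n",
    "Under this model the best possible speedup is bounded by 2x, reached only when the two times are equal."
   ]
  },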
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Without an asynchronous buffer for batch generation, these operations must run sequentially:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50.7 s ± 2.35 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit -n 1 train(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With batch generation overlapping the training step, the recorded runs drop from about 50.7 s to 31.1 s, a speedup of roughly 1.6x:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31.1 s ± 833 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit -n 1 train(True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
